1 Set up environment

1.1 Reset workspace and load libraries

This analysis uses ABCD Release 4.0

# Remove every object that ls() lists in the current environment so the
# analysis does not pick up leftover variables.
# NOTE(review): this only removes objects; loaded packages and options() stay
# as they are, so a fresh R session is the more reliable clean start.
rm(list=ls())
# Trigger garbage collection after the objects were removed.
gc()
library(tidyverse)
library(ggthemes)

1.2 Setting up paths

# Run r_functions.R from the stacking_gfactor_modelling folder.
# NOTE(review): `scriptfold` is not assigned anywhere in the visible code, and
# the workspace was cleared above, so it has to be defined before this line
# (presumably in a hidden chunk) — confirm. paste0() adds no separator, so
# `scriptfold` must already end with "/".
source(paste0(scriptfold,"stacking_gfactor_modelling/r_functions.R"))

Set up parallel processing

# Parallel backend used on Ubuntu: register doParallel with 15 cores.
# NOTE(review): nothing in the visible code calls foreach/%dopar%; presumably
# the backend is needed by helpers loaded from r_functions.R — confirm.
doParallel::registerDoParallel(cores=15)  

# Alternative for Ubuntu via doFuture (works, but is slow):
#library(doFuture)
#registerDoFuture()
#plan(multicore(workers = 30))

# Alternative for Windows via doFuture:

#library(doFuture)
#registerDoFuture()
#plan(multisession(workers = 30))

2 Loading the shap values

# Indices 1..21; each index selects one saved SHAP result file.
index_vec <- seq_len(21)

# Read one saved random-forest SHAP result object from disk.
#
# vec_input:  index that appears in the file name (default 1).
# time_input: time-point label that appears in the file name
#             (default "baseline").
# Returns the object stored in
#   <scriptfold>stacking_gfactor_modelling/rf_shap_results/
#     random_forest_<time_input>_results_shap_<vec_input>.RDS
# Relies on the global `scriptfold`.
process_shap_results <- function(vec_input = 1, time_input = "baseline") {
  result_dir <- paste0(scriptfold, "stacking_gfactor_modelling/rf_shap_results/")
  result_file <- paste0("random_forest_", time_input, "_results_shap_", vec_input, ".RDS")
  readRDS(paste0(result_dir, result_file))
}

# Read the baseline SHAP result for every index (time_input keeps its
# "baseline" default).
rf_shap_baseline <- map(index_vec, process_shap_results)

# Pull the test data, the training data and the SHAP values out of each result.
rf_shap_baseline_test_data <- rf_shap_baseline %>% map("test_data")
rf_shap_baseline_train_data <- rf_shap_baseline %>% map("train_data")
rf_shap_baseline_shap <- rf_shap_baseline %>% map("model_shap")


# Return the distinct values found in the SITE_ID_L element of `data_input`.
site_extract <- function(data_input) {
  unique(data_input[["SITE_ID_L"]])
}


# Site information: the distinct SITE_ID_L value(s) of each baseline test set,
# bound together with one row per index.
rf_shap_baseline_site <- do.call(
  rbind,
  map(rf_shap_baseline_test_data, site_extract)
)

# Same loading and extraction steps for the follow-up results.
rf_shap_followup <- map(
  index_vec,
  ~process_shap_results(vec_input = ., time_input = "followup")
)

rf_shap_followup_test_data <- rf_shap_followup %>% map("test_data")
rf_shap_followup_train_data <- rf_shap_followup %>% map("train_data")
rf_shap_followup_shap <- rf_shap_followup %>% map("model_shap")

rf_shap_followup_site <- do.call(
  rbind,
  map(rf_shap_followup_test_data, site_extract)
)

Setting up the continuous color palette:

# 24 hex colours for the continuous colour scale, in the order they are handed
# to scale_color_gradientn().
brain_plot_color <- c(
  "#240105", "#380208", "#450109", "#52010b",
  "#69010e", "#8a0313", "#990315", "#b00419",
  "#B2182B", "#D6604D", "#F4A582", "#FDDBC7",
  "#D1E5F0", "#92C5DE", "#4393C3", "#2166AC",
  "#0358ad", "#02509e", "#01488f", "#013870",
  "#002c59", "#012a54", "#012142", "#001933"
)

Setting up vectors to choose the right part of the output

# Identifier columns: subject, site and event.
subj_info <- c(
  "SUBJECTKEY",
  "SITE_ID_L",
  "EVENTNAME"
)

# Same selection with "gfactor" in place of the subject identifier.
subj_resp_noid <- c(
  "gfactor",
  "SITE_ID_L",
  "EVENTNAME"
)

Load in the names of plotting titles for all the modalities.

# Read the spreadsheet that maps original feature-set names to plot titles.
plotting_names <- paste0(scriptfold, "Common_psy_gene_brain_all/CommonalityPlotingNames.xlsx") %>%
  readxl::read_excel()

# Derive the join key `modality_names` by stripping "data" and "ROI_" from
# Original_name, then drop Original_name itself.
plotting_names_tidy <- plotting_names %>%
  mutate(
    modality_names = Original_name %>%
      str_remove_all("data") %>%
      str_remove_all("ROI\\_")
  ) %>%
  select(-Original_name)

3 Process the data and plot the Shapley summary plots

# Draw a SHAP summary (sina) plot: one row per feature set, one point per
# observation, coloured by the feature value.
#
# data_input:     data frame with the columns `plotting_name` (axis label of
#                 the feature set), `shapley_values` and `feature_values`.
# list_val_input: list of three numbers used as colour-bar breaks; they are
#                 labelled 'min', 'med' and 'max' in that order.
# Uses the global colour vector `brain_plot_color`.
# Returns a ggplot object.
shapley_summary_plot <- function(data_input, list_val_input) {
  # Limits of the SHAP axis: the observed range multiplied by 1.1.
  shap_axis_limits <- range(data_input$shapley_values) * 1.1
  colour_breaks <- unlist(list_val_input)

  # Order the feature sets by their largest SHAP value.
  ordered_data <- mutate(
    data_input,
    plotting_name = fct_reorder(plotting_name, shapley_values, .fun = "max")
  )

  colour_scale <- scale_color_gradientn(
    colours = brain_plot_color,
    na.value = "gray50",
    breaks = colour_breaks,
    labels = c("min", "med", "max"),
    guide = guide_colorbar(direction = "horizontal", barwidth = 12, barheight = 0.3)
  )

  text_and_legend_theme <- theme(
    axis.line.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "bottom",
    legend.direction = "horizontal",
    legend.title = element_text(size = 20),
    legend.text = element_text(size = 20),
    legend.box = "horizontal",
    axis.title.x = element_text(size = 20),
    axis.text.y = element_text(size = 15),
    axis.text.x = element_text(size = 20)
  )

  ggplot(
    ordered_data,
    aes(x = plotting_name, y = shapley_values, color = feature_values)
  ) +
    coord_flip(ylim = shap_axis_limits) +
    # reference line at a SHAP value of zero
    geom_hline(yintercept = 0) +
    ggforce::geom_sina(method = "counts", maxwidth = 0.7, alpha = 0.7) +
    colour_scale +
    theme_bw() +
    text_and_legend_theme +
    labs(
      y = "SHAP value (impact on model output)",
      x = "",
      color = "Feature value  "
    )
}


# Helper: pair SHAP values with feature values for the columns whose names end
# in `suffix` ("large" or "small").
# Returns a list with
#   wide: the SHAP columns ending in `suffix` plus SUBJECTKEY
#   long: one row per SUBJECTKEY x feature set holding shapley_values, the
#         columns of `plotting_names_tidy`, and feature_values
.shapley_pair_by_suffix <- function(shapley_input, data_input, suffix) {
  # SHAP rows carry no identifier of their own; they are assumed to be in the
  # same row order as `data_input`.
  shapley_wide <- shapley_input %>%
    tibble::as_tibble() %>%
    dplyr::select(ends_with(suffix)) %>%
    mutate(SUBJECTKEY = data_input$SUBJECTKEY)

  shapley_long <- shapley_wide %>%
    pivot_longer(-SUBJECTKEY, names_to = "modality_names", values_to = "shapley_values") %>%
    mutate(modality_names = str_remove_all(.data[["modality_names"]], suffix)) %>%
    dplyr::left_join(plotting_names_tidy, by = "modality_names")

  feature_long <- data_input %>%
    tibble::as_tibble() %>%
    dplyr::select(ends_with(suffix), "SUBJECTKEY") %>%
    pivot_longer(-SUBJECTKEY, names_to = "modality_names", values_to = "feature_values") %>%
    mutate(modality_names = str_remove_all(.data[["modality_names"]], suffix))

  list(
    wide = shapley_wide,
    long = dplyr::full_join(
      shapley_long,
      feature_long,
      by = c("SUBJECTKEY", "modality_names")
    )
  )
}

# Helper: rename the columns of a wide SHAP table by removing `suffix`.
.shapley_strip_suffix <- function(shapley_wide, suffix) {
  names(shapley_wide) <- str_remove_all(names(shapley_wide), suffix)
  shapley_wide
}

# Helper: min, median and max of `values` (ignoring NA) as a list, used as the
# colour-bar breaks of the summary plot.
.feature_value_breaks <- function(values) {
  lapply(list(min, median, max), function(f) f(values, na.rm = TRUE))
}

# Turn one set of SHAP values and the matching feature data into SHAP summary
# plots and a |SHAP| importance table.
#
# shapley_input: SHAP values, one column per feature; only columns whose names
#                end in "large" or "small" are used. Rows must be in the same
#                order as the rows of `data_input`.
# data_input:    data frame with a SUBJECTKEY column and the feature columns
#                ending in "large" / "small".
# Relies on the globals `plotting_names_tidy` and `shapley_summary_plot()`.
#
# Returns a list with
#   summary_plot_na: summary plot with feature values -1000 / 1000 set to NA
#   summary_plot_1k: summary plot with the feature values left untouched
#   importance_plot: bar chart of the summed |SHAP| value per feature set
#   shapley_vi_tibble_with_names: the table behind `importance_plot`
shapley_data_process_plotting <- function(shapley_input, data_input){
  large <- .shapley_pair_by_suffix(shapley_input, data_input, "large")
  small <- .shapley_pair_by_suffix(shapley_input, data_input, "small")

  # SHAP and feature values are joined within each suffix and only then
  # stacked. After the suffix is stripped both variants share the same
  # modality_names, so joining the stacked tables on
  # (SUBJECTKEY, modality_names) would pair "large" SHAP values with "small"
  # feature values (and vice versa).
  # distinct() is kept from the earlier implementation: fully identical rows
  # are collapsed into one.
  shapley_summary_tibble <- bind_rows(large$long, small$long) %>%
    distinct()

  ## shapley summary plot
  # NOTE(review): -1000 / 1000 look like placeholder codes for missing feature
  # values — confirm against the modelling script.
  shapley_summary_tibble_NA <- shapley_summary_tibble %>%
    naniar::replace_with_na(replace = list(feature_values = c(-1000, 1000)))
  shapley_summary_tibble_0 <- shapley_summary_tibble %>%
    mutate(
      feature_values = replace(feature_values, feature_values %in% c(-1000, 1000), 0)
    )

  # Colour-bar breaks: raw values for the "1k" plot, placeholder-free values
  # (placeholders replaced by 0) for the NA plot.
  list_val <- .feature_value_breaks(shapley_summary_tibble$feature_values)
  list_val_0 <- .feature_value_breaks(shapley_summary_tibble_0$feature_values)

  shapley_summary_plot_1k <- shapley_summary_plot(
    data_input = shapley_summary_tibble,
    list_val_input = list_val
  )
  shapley_summary_plot_na <- shapley_summary_plot(
    data_input = shapley_summary_tibble_NA,
    list_val_input = list_val_0
  )

  ### plot the sum of the absolute shap values
  # Each table is renamed from its own column names, so the result does not
  # depend on the "large" and "small" columns being in the same order.
  shapley_vi <- bind_rows(
    .shapley_strip_suffix(small$wide, "small"),
    .shapley_strip_suffix(large$wide, "large")
  ) %>%
    dplyr::select(-"SUBJECTKEY") %>%
    abs() %>%
    colSums()

  shapley_vi_tibble <- tibble(
    modality_names = names(shapley_vi),
    shapley_vi = shapley_vi
  )

  shapley_vi_tibble_with_names <- dplyr::left_join(
    shapley_vi_tibble,
    plotting_names_tidy,
    by = "modality_names"
  )

  shapley_vi_plot <- shapley_vi_tibble_with_names %>%
    mutate(plotting_name = fct_reorder(plotting_name, shapley_vi, .fun = "max")) %>%
    ggplot(aes(x = plotting_name, y = shapley_vi)) +
    geom_bar(stat = "identity") +
    coord_flip() +
    labs(y = "|SHAP| variable importance", x = "") +
    theme(
      axis.title.x = element_text(size = 20),
      axis.title.y = element_text(size = 20),
      axis.text.y = element_text(size = 15),
      axis.text.x = element_text(size = 20)
    )

  list(
    summary_plot_na = shapley_summary_plot_na,
    summary_plot_1k = shapley_summary_plot_1k,
    importance_plot = shapley_vi_plot,
    shapley_vi_tibble_with_names = shapley_vi_tibble_with_names
  )
}

Trial run: generate the output plots for a single site

# Run the processing/plotting function on the first baseline element only.
# NOTE(review): the SHAP values are matched row by row to the *training* data;
# confirm that model_shap was computed on train_data.
plots_one_site <- shapley_data_process_plotting(
  shapley_input = rf_shap_baseline_shap[[1]],
  data_input = rf_shap_baseline_train_data[[1]]
)

# Print the returned list (two summary plots, importance plot, importance table).
plots_one_site
## $summary_plot_na

## 
## $summary_plot_1k

## 
## $importance_plot

## 
## $shapley_vi_tibble_with_names
## # A tibble: 45 × 3
##    modality_names      shapley_vi plotting_name                
##    <chr>                    <dbl> <chr>                        
##  1 DTI_                     218.  DTI                          
##  2 rsmri_within_avg_        352.  rsfMRI cortical FC           
##  3 smri_T2_mean_total_       89.0 T2 summations                
##  4 smri_T1_mean_total_       45.2 T1 summations                
##  5 Normalised_T2_           121.  T2 normalised intensity      
##  6 Avg_T2_Gray_             310.  T2 gray matter avg intensity 
##  7 Avg_T2_White_            171.  T2 white matter avg intensity
##  8 Normalised_T1_           171.  T1 normalised intensity      
##  9 Avg_T1_Gray_             413.  T1 gray matter avg intensity 
## 10 Avg_T1_White_            263.  T1 white matter avg intensity
## # ℹ 35 more rows

Map the output across all sites

# Apply the processing/plotting function to every (SHAP values, training data)
# pair; the two lists are passed positionally as shapley_input and data_input.
shap_plots_baseline <- map2(
  rf_shap_baseline_shap,
  rf_shap_baseline_train_data,
  shapley_data_process_plotting
)

shap_plots_followup <- map2(
  rf_shap_followup_shap,
  rf_shap_followup_train_data,
  shapley_data_process_plotting
)

4 Get the sum plot of all the absolute shapley values across all sites

Extract the sum values

Compute the means across sites and sd across sites.

# Take the |SHAP| importance table out of every result and name the list
# elements after the site values extracted earlier.
shap_baseline_sum <- setNames(
  map(shap_plots_baseline, "shapley_vi_tibble_with_names"),
  rf_shap_baseline_site
)

shap_followup_sum <- setNames(
  map(shap_plots_followup, "shapley_vi_tibble_with_names"),
  rf_shap_followup_site
)

# Rename the three columns of `data_input`, by position, to
# "modality_names", <site_input>, "plotting_name" and return the result.
shap_sum_change_names <- function(data_input, site_input) {
  setNames(data_input, c("modality_names", site_input, "plotting_name"))
}
### Baseline: rename each importance column after its site, then merge all
### sites into one wide table keyed by feature set.
shap_baseline_sum_site <- map2(
  shap_baseline_sum,
  rf_shap_baseline_site,
  shap_sum_change_names
)

shap_baseline_cross_site <- plyr::join_all(
  shap_baseline_sum_site,
  by = c("modality_names", "plotting_name")
)

### Row-wise mean and standard deviation over the site columns.
shap_baseline_cross_site_num <- select(
  shap_baseline_cross_site,
  -all_of(c("modality_names", "plotting_name"))
)
mean_sum_shap_baseline <- rowMeans(shap_baseline_cross_site_num)
sd_sum_shap_baseline <- apply(shap_baseline_cross_site_num, 1, sd)

shap_baseline_cross_site <- mutate(
  shap_baseline_cross_site,
  mean_cross_site = mean_sum_shap_baseline,
  sd_cross_site = sd_sum_shap_baseline
)

### Follow-up: same steps as for baseline.
shap_followup_sum_site <- map2(
  shap_followup_sum,
  rf_shap_followup_site,
  shap_sum_change_names
)

shap_followup_cross_site <- plyr::join_all(
  shap_followup_sum_site,
  by = c("modality_names", "plotting_name")
)

### Row-wise mean and standard deviation over the site columns.
shap_followup_cross_site_num <- select(
  shap_followup_cross_site,
  -all_of(c("modality_names", "plotting_name"))
)
mean_sum_shap_followup <- rowMeans(shap_followup_cross_site_num)
sd_sum_shap_followup <- apply(shap_followup_cross_site_num, 1, sd)

shap_followup_cross_site <- mutate(
  shap_followup_cross_site,
  mean_cross_site = mean_sum_shap_followup,
  sd_cross_site = sd_sum_shap_followup
)

Plotting the baseline and followup shap values individually based on the magnitude.

### Bar chart of the cross-site mean |SHAP| importance with error bars of
### +/- one cross-site standard deviation.
###
### data_input: data frame with the columns mean_cross_site, sd_cross_site and
###             plotting_name.
### time_input: label appended to the value-axis title (default "baseline").
### Returns a ggplot object.
shap_sum_cross_site <- function(data_input, time_input = "baseline") {
  needed_cols <- c("mean_cross_site", "sd_cross_site", "plotting_name")
  value_axis_title <- paste0("|SHAP| variable importance at ", time_input)

  # Keep only the plotted columns and order the feature sets by their mean.
  ordered_data <- data_input %>%
    select(all_of(needed_cols)) %>%
    mutate(plotting_name = fct_reorder(plotting_name, mean_cross_site, .fun = "max"))

  ggplot(ordered_data, aes(x = plotting_name, y = mean_cross_site)) +
    geom_bar(stat = "identity", fill = "gray30", alpha = 0.7) +
    geom_errorbar(
      aes(
        x = plotting_name,
        ymin = mean_cross_site - sd_cross_site,
        ymax = mean_cross_site + sd_cross_site
      ),
      width = 0.4, colour = "black", alpha = 0.9, linewidth = 1.3
    ) +
    coord_flip() +
    labs(y = value_axis_title, x = "") +
    theme_few() +
    theme(
      axis.title.x = element_text(size = 20),
      axis.title.y = element_text(size = 20),
      axis.text.y = element_text(size = 15),
      axis.text.x = element_text(size = 15)
    )
}

# Baseline importance plot (time_input keeps its "baseline" default).
shap_sum_plot_baseline <- shap_sum_cross_site(shap_baseline_cross_site)

shap_sum_plot_baseline

# Follow-up importance plot.
shap_sum_plot_followup <- shap_sum_cross_site(shap_followup_cross_site, "followup")

shap_sum_plot_followup

Plot the summary plots. The baseline feature sets are ordered by their own magnitude, and the follow-up plot reuses the baseline order so that the two panels line up.

# Columns needed for the baseline panel.
shap_baseline_cross_site_select <- select(
  shap_baseline_cross_site,
  all_of(c("mean_cross_site", "sd_cross_site", "plotting_name"))
)

# Baseline panel: feature sets ordered by their cross-site mean.
shapley_vi_plot_baseline <- ggplot(
  mutate(
    shap_baseline_cross_site_select,
    plotting_name = fct_reorder(plotting_name, mean_cross_site, .fun = "max")
  ),
  aes(x = plotting_name, y = mean_cross_site)
) +
  geom_bar(stat = "identity", fill = "gray30", alpha = 0.7) +
  geom_errorbar(
    aes(
      x = plotting_name,
      ymin = mean_cross_site - sd_cross_site,
      ymax = mean_cross_site + sd_cross_site
    ),
    width = 0.4, colour = "black", alpha = 0.9, linewidth = 1.3
  ) +
  coord_flip() +
  labs(y = "", x = "") +
  theme_few() +
  theme(
    axis.title.x = element_text(size = 20),
    axis.title.y = element_text(size = 20),
    axis.text.y = element_text(size = 15),
    axis.text.x = element_text(angle = 90, size = 15, vjust = -0.2)
  )

#shapley_vi_plot_baseline
# Baseline feature-set names sorted by ascending cross-site mean; this order is
# reused for the follow-up panel.
shap_baseline_cross_site_select_reordered <- arrange(
  shap_baseline_cross_site_select,
  mean_cross_site
)

shap_baseline_cross_site_select_reordered_names <-
  shap_baseline_cross_site_select_reordered$plotting_name

# Columns needed for the follow-up panel.
shap_followup_cross_site_select <- select(
  shap_followup_cross_site,
  all_of(c("mean_cross_site", "sd_cross_site", "plotting_name"))
)

### Give the follow-up feature sets the baseline level order.
shap_followup_cross_site_select_reordered <- mutate(
  shap_followup_cross_site_select,
  plotting_name = factor(
    plotting_name,
    levels = shap_baseline_cross_site_select_reordered_names
  )
)

# Follow-up panel: same layers, feature-set axis text and title hidden.
shapley_vi_plot_followup <- ggplot(
  shap_followup_cross_site_select_reordered,
  aes(x = plotting_name, y = mean_cross_site)
) +
  geom_bar(stat = "identity", fill = "gray30", alpha = 0.7) +
  geom_errorbar(
    aes(
      x = plotting_name,
      ymin = mean_cross_site - sd_cross_site,
      ymax = mean_cross_site + sd_cross_site
    ),
    width = 0.4, colour = "black", alpha = 0.9, linewidth = 1.3
  ) +
  coord_flip() +
  labs(y = "", x = "") +
  theme_few() +
  theme(
    axis.title.x = element_text(size = 20),
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.text.x = element_text(angle = 90, size = 20)
  )

Align the baseline and follow-up figures side by side

#shapley_vi_plot_followup

# Add a heading above each panel.
shapley_vi_label_baseline <- ggpubr::annotate_figure(
  shapley_vi_plot_baseline,
  top = ggpubr::text_grob("Baseline", size = 20, hjust = -1)
)

shapley_vi_label_followup <- ggpubr::annotate_figure(
  shapley_vi_plot_followup,
  top = ggpubr::text_grob("Followup", size = 20, hjust = 1)
)

# Place the two panels in one row; the baseline panel gets three times the
# width of the follow-up panel.
shapley_vi_plot_aligned <- ggpubr::ggarrange(
  shapley_vi_label_baseline,
  shapley_vi_label_followup,
  ncol = 2,
  nrow = 1,
  widths = c(3, 1),
  heights = c(1, 1)
  #,common.legend=TRUE,legend = "top"
)

# Shared axis captions and the overall title.
labelled_shapley_vi_plot <- ggpubr::annotate_figure(
  shapley_vi_plot_aligned,
  left = ggpubr::text_grob("Sets of Features from the Brain", size = 25, rot = 90),
  bottom = ggpubr::text_grob("Mean |SHAP| Values", size = 25, hjust = -0.3),
  top = ggpubr::text_grob(
    "Feature importance of Each Set of Features \nfrom the Brain in Predicting Cognitive Abilities",
    size = 30,
    face = "bold"
  )
)

labelled_shapley_vi_plot